{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Split Dataset\n",
    "\n",
    "This notebook shows how to split a dataset into train, validation and test subsets."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Read data\n",
    "Use numpy and pandas to read the list of data files."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "np.random.seed(1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Tell pandas where the csv file is:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "csv_file_url = \"data/data.csv\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Check the data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "There are total 221565 examples in this dataset.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>file_basename</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>300vw-112-000322</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>300vw-402-000872</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>300vw-022-001242</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>300vw-143-000044</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>300vw-138-000175</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "      file_basename\n",
       "0  300vw-112-000322\n",
       "1  300vw-402-000872\n",
       "2  300vw-022-001242\n",
       "3  300vw-143-000044\n",
       "4  300vw-138-000175"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "full_data = pd.read_csv(csv_file_url)\n",
    "total_file_number = len(full_data)\n",
    "print(\"There are total {} examples in this dataset.\".format(total_file_number))\n",
    "full_data.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Split files\n",
    "There will be three groups: train, validation and test.\n",
    "\n",
    "Set the number of samples for each group in the following cell."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_train = 200000\n",
    "num_validation = 11565\n",
    "num_test = 10000"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Make sure there are enough examples for your choice."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Looks good! 200000 for train, 11565 for validation and 10000 for test.\n"
     ]
    }
   ],
   "source": [
    "# The groups are drawn without overlap, so their sizes together must fit in the dataset.\n",
    "num_requested = sum((num_train, num_validation, num_test))\n",
    "assert num_requested <= total_file_number, \"Not enough examples for your choice.\"\n",
    "print(\"Looks good! {} for train, {} for validation and {} for test.\".format(num_train, num_validation, num_test))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Randomly split the files."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Draw the groups one after another without replacement; removing each drawn set\n",
    "# from the pool keeps train, validation and test disjoint.\n",
    "index_all = np.arange(total_file_number)\n",
    "index_train = np.random.choice(total_file_number, size=num_train, replace=False)\n",
    "index_validation_test = np.setdiff1d(index_all, index_train)\n",
    "index_validation = np.random.choice(index_validation_test, size=num_validation, replace=False)\n",
    "index_leftover = np.setdiff1d(index_validation_test, index_validation)\n",
    "# Take exactly num_test of the leftover examples rather than all of them, so the\n",
    "# requested test size is honoured even when the three sizes sum to less than the total.\n",
    "# The result is sorted ascending, the same ordering np.setdiff1d produces.\n",
    "index_test = np.sort(np.random.permutation(index_leftover)[:num_test])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Gather the selected rows into sub-datasets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pick each group's rows by position using the index arrays drawn above.\n",
    "train, validation, test = (\n",
    "    full_data.iloc[positions]\n",
    "    for positions in (index_train, index_validation, index_test)\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Write to files"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# index=False (the documented boolean form) leaves the row labels out of the CSV files.\n",
    "train.to_csv(\"data/data_train.csv\", index=False)\n",
    "validation.to_csv(\"data/data_validation.csv\", index=False)\n",
    "test.to_csv(\"data/data_test.csv\", index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "All done!\n"
     ]
    }
   ],
   "source": [
    "print(\"All done!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  },
  "widgets": {
   "state": {},
   "version": "1.1.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
